#%% 
import shapreg  # https://github.com/iancovert/shapley-regression
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import argparse
def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so the flag could
    never be switched off from the command line.

    Args:
        value: The raw command-line token (or an actual ``bool``).

    Returns:
        ``True`` or ``False`` according to the textual value. The empty
        string maps to ``False``, matching what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# nargs='?' + const=True keeps `--test_training_speed True` working and also
# allows the bare flag `--test_training_speed` to enable timing.
parser.add_argument('--test_training_speed', type=_str2bool, nargs='?',
                    const=True, default=False,
                    help='print the wall-clock training time of the explainers')
args = parser.parse_args()
#%% 
# Read the dataset; the path is resolved relative to the working directory.
file = 'Taiwan_data_ENG_95.csv'
data = pd.read_csv(file, encoding='utf-8')

#%% 
# 'Flag' is the label column; every other column is used as an input feature.
X = np.array(data.drop(columns=['Flag']))
Y = np.array(data['Flag'])

# Two-stage split: hold out 20% for testing, then 20% of the remainder
# for validation. Seeds are fixed so the splits are reproducible.
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=70)
X_train, X_val, Y_train, Y_val = train_test_split(
    X_train, Y_train, test_size=0.2, random_state=42)
num_features = X_train.shape[1]

# Standardize all splits using statistics fitted on the training split only.
ss = StandardScaler().fit(X_train)
X_train, X_val, X_test = (
    ss.transform(split) for split in (X_train, X_val, X_test))

#%% Train Model
import pickle
import os.path
import lightgbm as lgb
from lightgbm import log_evaluation, early_stopping

#%% 
# Fit and cache the LightGBM model unless a pickled one is already on disk.
if not os.path.isfile('bank model.pkl'):
    # Training configuration handed to LightGBM.
    params = dict(
        max_bin=512,
        learning_rate=0.01,
        boosting_type="gbdt",
        objective="binary",
        metric="binary_logloss",
        num_leaves=10,
        verbose=-1,
        min_data=100,
        boost_from_average=True,
    )

    # Wrap the training / validation splits as LightGBM datasets.
    d_train = lgb.Dataset(X_train, label=Y_train)
    d_val = lgb.Dataset(X_val, label=Y_val)

    # Log every 1000 rounds; stop after 50 rounds without improvement
    # on the validation set.
    callbacks = [
        log_evaluation(period=1000),
        early_stopping(stopping_rounds=50),
    ]

    # Train model (third positional argument 10000 is kept as in the
    # original call).
    model = lgb.train(
        params,
        d_train,
        10000,
        valid_sets=[d_val],
        callbacks=callbacks,
    )

    # Cache the trained model for later runs.
    with open('bank model.pkl', 'wb') as fh:
        pickle.dump(model, fh)

else:
    print('Loading saved model')
    # NOTE: pickle executes code on load; only use a cache file this script
    # itself produced.
    with open('bank model.pkl', 'rb') as fh:
        model = pickle.load(fh)
#%% Train surrogate
import torch
import torch.nn as nn
from fastshap.utils import MaskLayer1d
from fastshap import Surrogate, KLDivLoss

# Select device: prefer the GPU, but fall back to the CPU so the script still
# runs (every later `.to(device)` call would otherwise fail) on machines
# without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#%% 
# Reuse a saved surrogate network if present; otherwise build, train and
# save one.
if os.path.isfile('bank surrogate.pt'):
    print('Loading saved surrogate model')
    surr = torch.load('bank surrogate.pt').to(device)
    surrogate = Surrogate(surr, num_features)

else:
    # Create surrogate model. The first linear layer takes 2 * num_features
    # inputs (presumably the masked features with the mask appended by
    # MaskLayer1d(append=True) -- confirm against fastshap) and the last
    # layer emits 2 outputs, one per class.
    surr = nn.Sequential(
        MaskLayer1d(value=0, append=True),
        nn.Linear(2 * num_features, 256),
        nn.ELU(inplace=True),
        nn.Linear(256, 512),
        nn.ELU(inplace=True),
        nn.Linear(512, 512), 
        nn.ELU(inplace=True),
        nn.Linear(512, 256),
        nn.ELU(inplace=True),
        nn.Linear(256, 128),
        nn.ELU(inplace=True),
        nn.Linear(128, 2)
        ).to(device)

    # Set up surrogate object
    surrogate = Surrogate(surr, num_features)

    # Set up original model
    def original_model(x):
        """Adapt the LightGBM model to a tensor-in / tensor-out callable.

        Args:
            x: A torch tensor of inputs; it is moved to the CPU and converted
                to NumPy before being passed to ``model.predict``.

        Returns:
            A float32 tensor on ``x.device`` with two columns,
            ``1 - pred`` and ``pred``, where ``pred`` is the output of
            ``model.predict``.
        """
        pred = model.predict(x.cpu().numpy())
        pred = np.stack([1 - pred, pred]).T
        return torch.tensor(pred, dtype=torch.float32, device=x.device)

    # Train. The wrapper `original_model` is passed here, not the raw
    # LightGBM `model`: the wrapper converts tensors to NumPy and returns
    # two-column probability tensors, which the raw booster does not do.
    surrogate.train_original_model(
        X_train,
        X_val,
        original_model,
        batch_size=64,
        max_epochs=150,
        lr=1e-4,
        loss_fn=KLDivLoss(),
        validation_samples=10,
        validation_batch_size=100,
        verbose=True)

    # Save surrogate from the CPU, then move it back to the compute device
    # for the rest of the script.
    surr.cpu()
    torch.save(surr, 'bank surrogate.pt')
    surr.to(device)

#%% Train FastSHAP
#%% fastshap
from simshap.fastshap_plus import FastSHAP
import time
# Train a new FastSHAP explainer unless a saved one is available on disk.
if not os.path.isfile('bank fastshap.pt'):
    # Explainer network: maps num_features inputs to 2 * num_features outputs.
    explainer_fastshap = nn.Sequential(
        nn.Linear(num_features, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 128),
        nn.ReLU(inplace=True),
        nn.Linear(128, 2 * num_features),
    )
    explainer_fastshap = explainer_fastshap.to(device)

    # Wrap the network and the surrogate in a FastSHAP object.
    fastshap = FastSHAP(
        explainer_fastshap,
        surrogate,
        link=nn.Identity(),
        normalization='additive',
    )

    # Train, reporting wall-clock time when --test_training_speed is set.
    if args.test_training_speed:
        start = time.time()
    fastshap.train(
        X_train,
        X_val[:100],
        batch_size=32,
        num_samples=32,
        max_epochs=200,
        validation_samples=128,
        verbose=True,
    )
    if args.test_training_speed:
        print('fastshap train time:', time.time() - start)

    # Save from the CPU, then return the network to the compute device.
    explainer_fastshap.cpu()
    torch.save(explainer_fastshap, 'bank fastshap.pt')
    explainer_fastshap.to(device)

else:
    print('Loading saved explainer model')
    explainer_fastshap = torch.load('bank fastshap.pt').to(device)
    fastshap = FastSHAP(
        explainer_fastshap,
        surrogate,
        normalization='additive',
        link=nn.Identity(),
    )
#%% Train simshap

from simshap.simshap_sampling import SimSHAPSampling
import sys
sys.path.append('..')
from models import SimSHAPTabular
import time
# Train a new SimSHAP explainer unless a saved one is available on disk.
if not os.path.isfile('bank simshap.pt'):
    # Create explainer model
    explainer = SimSHAPTabular(
        in_dim=num_features, hidden_dim=512, out_dim=2)
    explainer = explainer.to(device)

    # Set up SimSHAP object; the surrogate serves as the imputer.
    simshap = SimSHAPSampling(
        explainer=explainer, imputer=surrogate, device=device)

    # Train, reporting wall-clock time when --test_training_speed is set.
    if args.test_training_speed:
        start = time.time()
    simshap.train(
        X_train,
        X_val[:100],
        batch_size=1024,
        num_samples=64,
        max_epochs=1000,
        lr=7e-4,
        bar=False,
        validation_samples=128,
        verbose=True,
        lookback=10,
        lr_factor=0.5,
    )
    if args.test_training_speed:
        print('simshap train time:', time.time() - start)

    # Save from the CPU, then return the network to the compute device.
    explainer.cpu()
    torch.save(explainer, 'bank simshap.pt')
    explainer.to(device)

else:
    print('Loading saved explainer model')
    explainer = torch.load('bank simshap.pt').to(device)
    simshap = SimSHAPSampling(
        explainer=explainer, imputer=surrogate, device=device)

#%% Compare with KernelSHAP
import matplotlib.pyplot as plt
# Setup for KernelSHAP
def imputer(x, S):
    """Evaluate the surrogate on inputs ``x`` under feature subsets ``S``.

    Both arguments are converted to float32 tensors on ``device`` and passed
    to ``surrogate``; the result is handed back as a NumPy array so it can be
    consumed by the ``shapreg`` prediction game.
    """
    inputs = torch.tensor(x, dtype=torch.float32, device=device)
    subsets = torch.tensor(S, dtype=torch.float32, device=device)
    return surrogate(inputs, subsets).detach().cpu().numpy()

# Pick one test example reproducibly (fixed NumPy seed); `x` keeps a leading
# batch dimension of size 1 and `y` is its integer label.
np.random.seed(20)
ind = np.random.choice(len(X_test))
x, y = X_test[ind:ind+1], int(Y_test[ind])

# Explanations from the two learned explainers (SimSHAP is stored under the
# name `evoshap_values`); take the first and only example of each batch.
evoshap_values = simshap.shap_values(x)[0]
evoshap_values = evoshap_values.transpose(1, 0)
fastshap_values = fastshap.shap_values(x)[0]

# Reference values: run KernelSHAP until convergence is detected, keeping
# all intermediate results so that an early iterate can be plotted as well.
game = shapreg.games.PredictionGame(imputer, x)
shap_values, all_results = shapreg.shapley.ShapleyRegression(
    game,
    batch_size=32,
    paired_sampling=False,
    detect_convergence=True,
    bar=True,
    return_all=True,
)

# Create figure
plt.figure(figsize=(9, 5.5))

# Number of leading features to display. Every series below is sliced with
# this value (previously the slices hard-coded 10, so changing the constant
# produced a shape mismatch in plt.bar); clamp it to the number of features
# actually available.
plt_num_features = min(10, num_features)
positions = np.arange(plt_num_features)

# Bar chart: four series side by side, each bar a quarter of `width` wide.
width = 0.75
kernelshap_iters = 128
# Index of the intermediate KernelSHAP result recorded at `kernelshap_iters`
# (raises ValueError if no result was recorded at that count).
kernelshap_index = list(all_results['iters']).index(kernelshap_iters)

plt.bar(positions - width / 2, shap_values.values[:plt_num_features, y],
        width / 4, label='True SHAP values', color='tab:gray')
plt.bar(positions - width / 4, evoshap_values[:plt_num_features, y],
        width / 4, label='EvoSHAP', color='tab:green')
plt.bar(positions,
        fastshap_values[:plt_num_features, y],
        width / 4, label='fastSHAP', color='tab:blue')
plt.bar(positions + width / 4,
        all_results['values'][kernelshap_index][:plt_num_features, y],
        width / 4, label='KernelSHAP @ {}'.format(kernelshap_iters), color='tab:red')

# Annotations
plt.legend(fontsize=16)
plt.tick_params(labelsize=14)
plt.ylabel('SHAP Values', fontsize=16)
plt.title('Bank Explanation Example', fontsize=18)

# Write the figure to disk before showing it.
plt.tight_layout()
plt.savefig('bank simshap.png')
plt.show()